{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdb5ffea",
   "metadata": {},
   "source": [
    "## Notebook 4 - Analysis of MS spectra"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa724fbe",
   "metadata": {},
   "source": [
    "This notebook identifies products based on the database of MS2 spectra generated previously."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95723f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run ../common.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f1026e25",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Determination of memory status is not supported on this \n",
      " platform, measuring for memoryleaks will never fail\n"
     ]
    }
   ],
   "source": [
    "from pyopenms import MSExperiment, MzMLFile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80f5bdc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: every enzyme seen anywhere in the screen should have a file in each mix\n",
    "\n",
    "# Enzyme ids (file-name prefix before the first underscore) found across all mixes\n",
    "eids_screen = {\n",
    "    os.path.basename(path).split(\"_\")[0]\n",
    "    for path in glob.glob('../data/MS_2/Mix*/mzML/*.mzML')\n",
    "}\n",
    "\n",
    "for mix_no in range(1, 13):\n",
    "    mix_paths = glob.glob(f'../data/MS_2/Mix{mix_no}/mzML/*.mzML')\n",
    "    eids_present = {os.path.basename(path).split(\"_\")[0] for path in mix_paths}\n",
    "\n",
    "    # Report only the mixes in which at least one enzyme has no file\n",
    "    missing_eids = eids_screen.difference(eids_present)\n",
    "    if missing_eids:\n",
    "        print(f'Mix {mix_no}')\n",
    "        print(f'Enzymes missing: {missing_eids}\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "70a5f392",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We already prepared the reference database\n",
    "# process_mix reads its 'Mix', 'Name', 'CSMILES', 'M_charge', 'mz', 'relint' and 'DB#' columns,\n",
    "# plus every column whose name contains 'M+' (one per adduct type).\n",
    "# NOTE(review): unpickling can execute arbitrary code, so only load files produced by this project.\n",
    "# NOTE(review): filepath_results is not defined in this notebook; presumably set by ../common.py - confirm.\n",
    "df_database = pd.read_pickle(filepath_results + 'MS2_database_shifts_162_320_VB_clean.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cfd628be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_c_glucoside(mz_peaks, i_peaks, charge, mz_precursor, adduct_type, tol=0.05, i_threshold=0.10):\n",
    "    shifts = np.array([120.042, 90.032])\n",
    "    if charge == 0:\n",
    "        if adduct_type in ['M+Glu+H','M+Glu+Na','M+Glu+NH4','M+Glu+ACN+H','M+2Glu+H','M+2Glu+Na','M+2Glu+NH4','M+2Glu+ACN+H']:\n",
    "            pass\n",
    "        elif adduct_type in ['M+Glu+2H','M+2Glu+2H']:\n",
    "            shifts = shifts / 2.\n",
    "    elif charge ==1:\n",
    "        if adduct_type in ['M+Glu+H', 'M+2Glu+H']:\n",
    "            pass\n",
    "        elif adduct_type in ['M+Glu+2H','M+Glu+Na','M+Glu+NH4','M+Glu+ACN+H','M+2Glu+2H','M+2Glu+Na','M+2Glu+NH4','M+2Glu+ACN+H']:\n",
    "            shifts = shifts / 2.\n",
    "            \n",
    "    i_peaks = i_peaks / max(i_peaks)\n",
    "    \n",
    "    for shift in shifts:\n",
    "        test = np.where(np.abs(mz_precursor - mz_peaks - shift) < tol)[0]\n",
    "        for i in test:\n",
    "            if i_peaks[i] > i_threshold:\n",
    "                return True\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ac39e267",
   "metadata": {},
   "outputs": [],
   "source": [
    "def integrate_MS1_row(row, exp, verbose):\n",
    "    \"\"\"Integrate the MS1 signal of one hit over m/z and retention time.\n",
    "\n",
    "    The m/z window is PrecursorMZ*(1 +/- ppm_integrate/1e6) and the time window is\n",
    "    RT +/- rt_semiwindow_integrate. Both limits are notebook-level globals that are\n",
    "    not defined in this notebook (presumably set by ../common.py - TODO confirm).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    row : mapping (e.g. one row of the hits DataFrame)\n",
    "        Must provide 'RT' and 'PrecursorMZ'; 'Name' and 'CosineScore' are only\n",
    "        read when `verbose` is true.\n",
    "    exp : iterable of spectra\n",
    "        Each spectrum must expose getMSLevel(), getRT() and get_peaks().\n",
    "    verbose : bool\n",
    "        If true, print which hit is being integrated.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Trapezoidal integral over time of the per-spectrum m/z integrals; 0 when no\n",
    "        MS1 spectrum has a peak inside the window.\n",
    "    \"\"\"\n",
    "    if verbose:\n",
    "        print(f\"Product {row['Name']} identified with score {row['CosineScore']}\")\n",
    "\n",
    "    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists\n",
    "    trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz\n",
    "\n",
    "    rt_hit = row['RT']\n",
    "    mz_hit = row['PrecursorMZ']\n",
    "\n",
    "    mz_lb = mz_hit*(1-ppm_integrate/1e6)\n",
    "    mz_ub = mz_hit*(1+ppm_integrate/1e6)\n",
    "    rt_lb = rt_hit - rt_semiwindow_integrate\n",
    "    rt_ub = rt_hit + rt_semiwindow_integrate\n",
    "\n",
    "    slice_rt = []\n",
    "    slice_auc = []\n",
    "    for spec in exp:\n",
    "        if spec.getMSLevel() != 1:\n",
    "            continue\n",
    "        rt = spec.getRT()\n",
    "        if rt < rt_lb:\n",
    "            continue\n",
    "        if rt > rt_ub:\n",
    "            # NOTE(review): stopping here assumes spectra are iterated in increasing RT - confirm\n",
    "            break\n",
    "\n",
    "        mz_peaks, i_peaks = spec.get_peaks()\n",
    "        in_window = (mz_peaks > mz_lb) & (mz_peaks < mz_ub)\n",
    "        if not in_window.any():\n",
    "            continue\n",
    "\n",
    "        # Integrate along the m/z dimension to obtain the AUC for that time slice;\n",
    "        # zero-intensity points at both window edges close the trapezoid.\n",
    "        # No baseline is subtracted.\n",
    "        i_slice = np.concatenate(([0.], i_peaks[in_window], [0.]))\n",
    "        mz_slice = np.concatenate(([mz_lb], mz_peaks[in_window], [mz_ub]))\n",
    "        slice_rt.append(rt)\n",
    "        slice_auc.append(trapezoid(y=i_slice, x=mz_slice))\n",
    "\n",
    "    # Integrate along the time dimension, again closing the profile with zeros at the edges\n",
    "    t_integrate = np.concatenate(([rt_lb], slice_rt, [rt_ub]))\n",
    "    auc_profile = np.concatenate(([0.], slice_auc, [0.]))\n",
    "    return trapezoid(y=auc_profile, x=t_integrate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32078b63",
   "metadata": {},
   "source": [
    "We record all hits with `CosineScore` > 0.5. Then, we can plot the number of hits vs. the desired cutoff (say CosineScore = 0.85)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d16414e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimum CosineScore for recording a hit: process_mix keeps a match only if its score is strictly above this value\n",
    "score_threshold = 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe825e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_mix(mix_no, rt_dict=None, verbose=False):\n",
    "    \"\"\"Identify products in every mzML file of one mix and integrate their MS1 signal.\n",
    "\n",
    "    For each MS2 spectrum, the reference entries of this mix whose precursor m/z (any\n",
    "    'M+...' adduct column) lies within `ppm` of the spectrum precursor are scored with\n",
    "    `cosine_greedy_score`. When a reference RT is known for the candidate, it must also\n",
    "    lie within `rt_semiwindow_cutoff` of the spectrum RT. The best candidate is recorded\n",
    "    if its score is strictly above `score_threshold`. Per file, only the highest-scoring\n",
    "    hit of each product name is kept and integrated with `integrate_MS1_row`.\n",
    "\n",
    "    Uses notebook-level names: df_database, score_threshold, ppm, rt_semiwindow_cutoff\n",
    "    and cosine_greedy_score (the last three are not defined in this notebook; presumably\n",
    "    they come from ../common.py - TODO confirm).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mix_no : int\n",
    "        Mix number; selects the database rows and the ../data/MS_2/Mix{mix_no}/mzML folder.\n",
    "    rt_dict : dict or None\n",
    "        Maps a product name, as written in the 'Name' column of the result (i.e. including\n",
    "        the ' (2GLC)' suffix of 2Glu adducts), to its reference RT. Candidates without an\n",
    "        entry are not filtered on RT. None behaves like an empty dict.\n",
    "    verbose : bool\n",
    "        Forwarded to `integrate_MS1_row`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        One row per (file, product name) with the columns listed in `columns` below;\n",
    "        empty, with the same columns, when nothing is found.\n",
    "    \"\"\"\n",
    "    # None instead of a shared mutable {} default\n",
    "    if rt_dict is None:\n",
    "        rt_dict = {}\n",
    "\n",
    "    ### Prepare reference\n",
    "    df_mix = df_database[df_database['Mix'] == mix_no]\n",
    "\n",
    "    ref_name = list(df_mix['Name'])\n",
    "    ref_smiles = list(df_mix['CSMILES'])\n",
    "    ref_substrates_charges = np.array(df_mix['M_charge'])\n",
    "    colnames = [a for a in df_mix.columns if 'M+' in a]  # one column per adduct type\n",
    "    ref_precursors_mz = np.array(df_mix[colnames])  # rows: reference entries, columns: adducts\n",
    "    ref_fragments_mz = list(df_mix['mz'])\n",
    "    ref_fragments_i = list(df_mix['relint'])\n",
    "    ref_mona_no = list(df_mix['DB#'])\n",
    "\n",
    "    ###\n",
    "\n",
    "    # Sorted so that the row order of the result does not depend on the file system\n",
    "    files = sorted(glob.glob(f'../data/MS_2/Mix{mix_no}/mzML/*.mzML'))\n",
    "\n",
    "    columns = ['File', 'Mix', 'Enzyme_id', 'Name', 'CSMILES', 'PrecursorMZ', 'Adduct', 'RT',\n",
    "               'CosineScore', 'C-gly', 'AUC', 'MoNA_DB#']\n",
    "    hit_columns = [col for col in columns if col != 'AUC']\n",
    "\n",
    "    # Per-file results are collected and concatenated once at the end\n",
    "    frames = []\n",
    "    for c, filename in enumerate(files, start=1):\n",
    "        file = os.path.basename(filename)\n",
    "        print(f'\\n{file} (file {c} out of {len(files)})')\n",
    "\n",
    "        enzyme_id = file.split('_')[0]\n",
    "\n",
    "        exp = MSExperiment()\n",
    "        MzMLFile().load(filename, exp)\n",
    "\n",
    "        ## SUBSET EXPERIMENTAL MS2 SPECTRA\n",
    "\n",
    "        hits = []\n",
    "        for spec in exp:\n",
    "            if spec.getMSLevel() != 2:\n",
    "                continue\n",
    "            precursors = spec.getPrecursors()\n",
    "            if not precursors:\n",
    "                # An MS2 spectrum without precursor information cannot be matched\n",
    "                continue\n",
    "            mz_precursor = precursors[0].getMZ()\n",
    "            mz_peaks, i_peaks = spec.get_peaks()\n",
    "            rt = spec.getRT()  # NOTE(review): originally annotated as minutes - confirm the unit\n",
    "\n",
    "            # Boolean 2D matrix (reference entry x adduct) of precursor matches within `ppm`\n",
    "            test1 = ((np.abs(mz_precursor - ref_precursors_mz)/mz_precursor)*1e6) < ppm\n",
    "\n",
    "            best = None\n",
    "            score_best = score_threshold\n",
    "            for i, j in zip(*np.where(test1)):\n",
    "                # The RT dictionary is keyed by the reported name, which carries the\n",
    "                # ' (2GLC)' suffix for 2Glu adducts, so build that name before the lookup\n",
    "                name = ref_name[i]\n",
    "                if '2Glu' in colnames[j]:\n",
    "                    name += ' (2GLC)'\n",
    "\n",
    "                # We also test for retention time; without a reference RT any value is accepted\n",
    "                ref_rt = rt_dict.get(name)\n",
    "                if ref_rt is not None and not (np.abs(rt - ref_rt) < rt_semiwindow_cutoff):\n",
    "                    continue\n",
    "\n",
    "                score = cosine_greedy_score(mz_peaks, i_peaks, ref_fragments_mz[i], ref_fragments_i[i])\n",
    "                if score > score_best:\n",
    "                    score_best = score\n",
    "                    best = (i, j, name)\n",
    "\n",
    "            if best is None:\n",
    "                continue\n",
    "\n",
    "            I, J, name = best\n",
    "            adduct_type = colnames[J]\n",
    "\n",
    "            # We consider that the hit may be a C-glycoside\n",
    "            # if it has a -120 or -90 Da neutral loss\n",
    "            charge = ref_substrates_charges[I]\n",
    "            is_c_glu = check_c_glucoside(mz_peaks, i_peaks, charge, mz_precursor, adduct_type)\n",
    "\n",
    "            hits.append({\n",
    "                'File': filename,\n",
    "                'Mix': mix_no,\n",
    "                'Enzyme_id': enzyme_id,\n",
    "                'Name': name,\n",
    "                'CSMILES': ref_smiles[I],\n",
    "                'PrecursorMZ': np.round(mz_precursor, 4),\n",
    "                'Adduct': adduct_type,\n",
    "                'RT': np.round(rt, 2),\n",
    "                'CosineScore': np.round(score_best, 3),\n",
    "                'C-gly': is_c_glu,\n",
    "                'MoNA_DB#': ref_mona_no[I],\n",
    "            })\n",
    "\n",
    "        df_temp = pd.DataFrame(hits, columns=hit_columns)\n",
    "\n",
    "        # Let us now remove duplicate products based on their name\n",
    "        # The sorting ensures we keep the highest score when dropping\n",
    "        df_temp = (df_temp.sort_values(by=['CosineScore'], ascending=[False])\n",
    "                          .drop_duplicates(subset=['Name'], keep='first'))\n",
    "\n",
    "        print(f'{len(df_temp)} products identified')\n",
    "\n",
    "        ## INTEGRATE IN THE MS1 CHANNEL (2D integration)\n",
    "\n",
    "        # Explicit row loop: DataFrame.apply(axis=1) on a file without hits returns an\n",
    "        # empty DataFrame, which cannot be assigned to a single column\n",
    "        list_hit_AUC = [integrate_MS1_row(row, exp, verbose) for _, row in df_temp.iterrows()]\n",
    "        df_temp['AUC'] = np.round(np.asarray(list_hit_AUC, dtype=float), 0)\n",
    "\n",
    "        if not df_temp.empty:\n",
    "            frames.append(df_temp)\n",
    "\n",
    "    if not frames:\n",
    "        return pd.DataFrame(columns=columns)\n",
    "    return pd.concat(frames, ignore_index=True)[columns]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d16a364",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Only hits with CosineScore at or above this value are used to build the RT dictionary\n",
    "reliable_score = 0.85\n",
    "# Let us only run 1 iteration\n",
    "max_iterations = 1\n",
    "\n",
    "# Create the output folder up front so the long computation cannot fail on the final write\n",
    "os.makedirs('./tmp', exist_ok=True)\n",
    "\n",
    "rt_dict = {}\n",
    "i = 1\n",
    "while True:\n",
    "    print(f'\\nIteration {i}')\n",
    "\n",
    "    # Collect the per-mix results and concatenate them once\n",
    "    frames = []\n",
    "    for mix_no in range(1,13,1):\n",
    "        print(f'\\nProcessing mix {mix_no}')\n",
    "        df_tmp = process_mix(mix_no, rt_dict)\n",
    "        if not df_tmp.empty:\n",
    "            frames.append(df_tmp)\n",
    "    df = pd.concat(frames) if frames else pd.DataFrame()\n",
    "\n",
    "    df.to_pickle(f'./tmp/Allmixes_results_cosine_iteration_{i}.pkl')\n",
    "    df.to_csv(f'./tmp/Allmixes_results_cosine_iteration_{i}.csv', index=False)\n",
    "\n",
    "    # Let us update the rt_dict for the next iteration.\n",
    "    # To compute the dictionary, we get the median RT for each product name observed across all experiments\n",
    "    # We only use reliable hits for computing the dictionary (i.e., CosineScore>=reliable_score)\n",
    "    if df.empty:\n",
    "        # No hit at all: there is no 'CosineScore' column to filter on\n",
    "        rt_dict_new = {}\n",
    "    else:\n",
    "        df_tmp = df[df['CosineScore']>=reliable_score]\n",
    "        rt_dict_new = dict(df_tmp.groupby(['Name'])['RT'].median())\n",
    "\n",
    "    # Stop once the dictionary no longer changes or the iteration budget is used up.\n",
    "    # `or` is required here: `a | i == 1` parses as `(a | i) == 1`, which never\n",
    "    # becomes true once i > 1.\n",
    "    if rt_dict_new == rt_dict or i >= max_iterations:\n",
    "        break\n",
    "\n",
    "    rt_dict = rt_dict_new\n",
    "    i+=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9700fa6f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "project_gt",
   "language": "python",
   "name": "project_gt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
